{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 3: k-Nearest Neighbors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
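    "1. The $k$-nearest neighbor ($k$-NN) method is a basic and simple method for classification and regression. It works as follows. Given a set of training instances and an input instance, it first finds the $k$ training instances nearest to the input instance. It then predicts the class of the input instance as the majority class among those $k$ training instances.\n",
    "\n",
    "2. A $k$-NN model corresponds to a partition of the feature space based on the training data set. Once the training set, the distance metric, the value of $k$, and the classification decision rule are fixed, the result of $k$-NN is uniquely determined.\n",
    "\n",
    "3. The three elements of $k$-NN are the distance metric, the choice of $k$, and the classification decision rule. The most commonly used distance metric is the Euclidean distance, or more generally the $L_p$ distance. A small $k$ gives a more complex $k$-NN model, and a large $k$ gives a simpler one. The choice of $k$ reflects a trade-off between approximation error and estimation error. The optimal $k$ is usually chosen by cross-validation.\n",
    "\n",
    "The most common classification decision rule is majority voting, which corresponds to empirical risk minimization.\n",
    "\n",
    "4. Implementing $k$-NN requires a fast way to search for the $k$ nearest neighbors. The **kd** tree is a data structure for fast retrieval of points in a $k$-dimensional space. A kd tree is a binary tree that represents a partition of the $k$-dimensional space. Each node corresponds to one hyperrectangular region of that partition. A **kd** tree lets the search skip most of the data points, which reduces its computational cost."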
   ]
  },
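  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The majority-vote rule in item 3 can be written out explicitly. The notation here is assumed for illustration: $N_k(x)$ is the set of the $k$ training points nearest to $x$, $c_1, \\dots, c_K$ are the classes, and $I(\\cdot)$ is the indicator function. With this notation, the predicted class is\n",
    "\n",
    "$$y=\\arg\\max_{c_j} \\sum_{x_i \\in N_k(x)} I\\left(y_i=c_j\\right)$$\n",
    "\n",
    "Under the 0-1 loss, suppose the neighborhood $N_k(x)$ is assigned class $c_j$. Its empirical misclassification rate is then\n",
    "\n",
    "$$\\frac{1}{k} \\sum_{x_i \\in N_k(x)} I\\left(y_i \\neq c_j\\right)=1-\\frac{1}{k} \\sum_{x_i \\in N_k(x)} I\\left(y_i=c_j\\right)$$\n",
    "\n",
    "Minimizing the empirical risk is therefore the same as maximizing $\\sum_{x_i \\in N_k(x)} I\\left(y_i=c_j\\right)$, which is exactly the majority vote. For example, take $k=5$ and neighbor labels $(0, 1, 1, 0, 1)$. Predicting class 1 gives an error rate of $2/5$, and predicting class 0 gives $3/5$, so the vote picks class 1."
   ]
  },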
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distance metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
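    "Let the feature space $\\mathcal{X}$ be the $n$-dimensional real vector space $\\mathbf{R}^{n}$. Let $x_{i}, x_{j} \\in \\mathcal{X}$, with $x_{i}=\\left(x_{i}^{(1)}, x_{i}^{(2)}, \\cdots, x_{i}^{(n)}\\right)^{\\mathrm{T}}$ and $x_{j}=\\left(x_{j}^{(1)}, x_{j}^{(2)}, \\cdots, x_{j}^{(n)}\\right)^{\\mathrm{T}}$.\n",
    "The $L_p$ distance (also called the Minkowski distance) between $x_i$ and $x_j$ is defined, for $p \\geq 1$, as:\n",
    "\n",
    "\n",
    "$L_{p}\\left(x_{i}, x_{j}\\right)=\\left(\\sum_{l=1}^{n}\\left|x_{i}^{(l)}-x_{j}^{(l)}\\right|^{p}\\right)^{\\frac{1}{p}}$\n",
    "\n",
    "- $p = 1$: Manhattan distance\n",
    "- $p = 2$: Euclidean distance\n",
    "- $p = \\infty$: Chebyshev distance, the largest coordinate difference $L_{\\infty}\\left(x_{i}, x_{j}\\right)=\\max_{l}\\left|x_{i}^{(l)}-x_{j}^{(l)}\\right|$"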
   ]
  },
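  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As $p \\to \\infty$, the $L_p$ distance tends to the largest coordinate difference. Let $M=\\max_{l}\\left|x_{i}^{(l)}-x_{j}^{(l)}\\right|$. Each of the $n$ terms in the sum is at most $M^{p}$, and at least one of them equals $M^{p}$, so\n",
    "\n",
    "$$M \\le L_{p}\\left(x_{i}, x_{j}\\right) \\le \\left(n M^{p}\\right)^{\\frac{1}{p}}=n^{\\frac{1}{p}} M$$\n",
    "\n",
    "Since $n^{1/p} \\to 1$ as $p \\to \\infty$, it follows that $L_{p}\\left(x_{i}, x_{j}\\right) \\to M$. For example, take $x_1=(1,1)^{\\mathrm{T}}$ and $x_3=(4,4)^{\\mathrm{T}}$, the points of Example 3.1. Here $L_1=6$, $L_2=\\sqrt{18} \\approx 4.24$ and $L_{\\infty}=3$. The distance between two fixed points never increases as $p$ grows."
   ]
  },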
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from itertools import combinations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
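    "def L(x, y, p=2):\n",
    "    \"\"\"L_p distance between two vectors x and y of equal length (p >= 1).\"\"\"\n",
    "    if len(x) != len(y) or len(x) == 0:\n",
    "        raise ValueError('x and y must be non-empty and of equal length')\n",
    "    total = 0\n",
    "    for i in range(len(x)):\n",
    "        total += math.pow(abs(x[i] - y[i]), p)\n",
    "    return math.pow(total, 1 / p)"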
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Textbook Example 3.1\n",
    "\n",
    "Three points in two-dimensional space are given: $x_{1}=(1,1)^{\\mathrm{T}}$, $x_{2}=(5,1)^{\\mathrm{T}}$ and $x_{3}=(4,4)^{\\mathrm{T}}$. Find the nearest neighbor of $x_{1}$ under the $L_p$ distance for different values of $p$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = [1, 1]\n",
    "x2 = [5, 1]\n",
    "x3 = [4, 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4.0, 'nearest point: [5, 1]')\n",
      "(4.0, 'nearest point: [5, 1]')\n",
      "(3.7797631496846193, 'nearest point: [4, 4]')\n",
      "(3.5676213450081633, 'nearest point: [4, 4]')\n"
     ]
    }
   ],
   "source": [
    "# Consider p = 1 to 4 in the distance metric\n",
    "for i in range(1, 5):\n",
    "    r = {'nearest point: {}'.format(c): L(x1, c, p=i) for c in [x2, x3]}\n",
    "    print(min(zip(r.values(), r.keys())))"
   ]
  },
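  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed results can be checked by hand. $x_1$ and $x_2$ differ only in the first coordinate, so $L_p(x_1, x_2)=\\left(4^{p}\\right)^{\\frac{1}{p}}=4$ for every $p$. $x_1$ and $x_3$ differ by 3 in both coordinates:\n",
    "\n",
    "$$L_p(x_1, x_3)=\\left(3^{p}+3^{p}\\right)^{\\frac{1}{p}}=3 \\cdot 2^{\\frac{1}{p}}$$\n",
    "\n",
    "| $p$ | $L_p(x_1, x_2)$ | $L_p(x_1, x_3)$ | nearest neighbor of $x_1$ |\n",
    "|---|---|---|---|\n",
    "| 1 | 4 | 6 | $x_2$ |\n",
    "| 2 | 4 | $\\approx 4.24$ | $x_2$ |\n",
    "| 3 | 4 | $\\approx 3.78$ | $x_3$ |\n",
    "| 4 | 4 | $\\approx 3.57$ | $x_3$ |\n",
    "\n",
    "$x_3$ becomes the nearest neighbor once $3 \\cdot 2^{1/p} < 4$, that is, once $p > 1/\\log_2(4/3) \\approx 2.41$."
   ]
  },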
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KNN implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import the required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
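    "This experiment uses the Iris dataset, which records 4 features for each of its 150 samples: sepal length, sepal width, petal length, and petal width. Only the first two features (sepal length and sepal width) are used, together with the first two classes (the first 100 samples, labels 0 and 1)."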
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data into a DataFrame and set its column names\n",
    "iris = load_iris()\n",
    "df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "df['label'] = iris.target\n",
    "df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length</th>\n",
       "      <th>sepal width</th>\n",
       "      <th>petal length</th>\n",
       "      <th>petal width</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length  sepal width  petal length  petal width  label\n",
       "0             5.1          3.5           1.4          0.2      0\n",
       "1             4.9          3.0           1.4          0.2      0\n",
       "2             4.7          3.2           1.3          0.2      0\n",
       "3             4.6          3.1           1.5          0.2      0\n",
       "4             5.0          3.6           1.4          0.2      0\n",
       "..            ...          ...           ...          ...    ...\n",
       "145           6.7          3.0           5.2          2.3      2\n",
       "146           6.3          2.5           5.0          1.9      2\n",
       "147           6.5          3.0           5.2          2.0      2\n",
       "148           6.2          3.4           5.4          2.3      2\n",
       "149           5.9          3.0           5.1          1.8      2\n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x1d4636287c0>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEJCAYAAACZjSCSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAff0lEQVR4nO3df5wddX3v8de7YTVRgVRYK2QTokJzlYAEVjDGooItGiKkSLnwqHpRH031WsUHFR/FWvViW7RYtci9IIpVizc0WgyK/NALoqglND8gkcQIFm024WoabwJIkLB+7h8zm2wOZ3fP7J7vOTNn3s/HYx+7Z86c735mBvaTmfl85quIwMzM6uu3uh2AmZl1lxOBmVnNORGYmdWcE4GZWc05EZiZ1ZwTgZlZzSVPBJKmSVon6cYm771S0i5J9+RfH0gdj5mZ7e+ADvyOC4BNwEFjvH9nRCzpQBxmZtZE0kQgaQA4Hfgb4MJ2jHnooYfG3Llz2zGUmVltrFmz5j8jor/Ze6nPCD4JvBc4cJx1Fkq6F9gGvCci7htvwLlz57J69er2RWhmVgOSfjbWe8nuEUhaAvwiItaMs9pa4IiIeDHwKWDlGGMtk7Ra0urt27e3P1gzsxpLebN4EXCGpJ8C1wGnSLp29AoR8XBEPJr/fBPQJ+nQxoEi4uqIGIyIwf7+pmc2ZmY2SckSQURcHBEDETEXOBe4PSLeMHodSc+VpPznE/N4dqSKyczMnqoTVUP7kfQ2gIi4CjgbeLukJ4HdwLnhx6GaWUnt2bOHoaEhHn/88W6HMqbp06czMDBAX19fy59R1f7uDg4Ohm8Wm1k3PPjggxx44IEccsgh5BczSiUi2LFjB4888gjPe97z9ntP0pqIGGz2uY6fEZjVxcp1W7ns1s1s27mbw2fO4KLT5rF0waxuh2VT8PjjjzN37txSJgEASRxyyCEULapxIjBLYOW6rVx8/QZ27xkGYOvO3Vx8/QYAJ4OKK2sSGDGZ+PysIbMELrt1894kMGL3nmEuu3VzlyIyG5sTgVkC23buLrTcrFW33HIL8+bN48gjj+QjH/lIW8Z0IjBL4PCZMwotN2vF8PAw73jHO7j55pvZuHEjy5cvZ+PGjVMe14nALIGLTpvHjL5p+y2b0TeNi06b16WIrBtWrtvKoo/czvP+4hss+sjtrFy3dUrj3X333Rx55JE8//nP52lPexrnnnsuN9xww5TjdCIwS2DpgllcetYxzJo5AwGzZs7g0rOO8Y3iGhkpGNi6czfBvoKBqSSDrVu3Mnv27L2vBwYG2Lp1askFXDVklszSBbP8h7/GxisYmOx/F836vtpRxeQzAjOzBFIUDAwMDLBly5a9r4eGhjj88MMnPd4IJwIzswRSFAy85CUv4f777+fBBx/kiSee4LrrruOMM86Y9HgjnAjMzBJIUTBwwAEHcMUVV3Daaafxwhe+kHPOOYejjz56qqH6HoGZWQoj9wHa/ZiRxYsXs3jx4naEuJcTgZlZIlUpGPClITOzmnMiMDOrOScCM7OacyIwM6s53yy22vMEMlZ3PiOwWkvxPBizlN7ylrfwnOc8h/nz57dtTCcCqzVPIGNVc/7553PLLbe0dUwnAqs1TyBjSa1fAZ+YDx+amX1fv2LKQ5588sk8+9nPnnpsozgRWK15AhlLZv0K+Pq7YNcWILLvX39XW5JBuzkRWK15AhlL5rZLYE/DmeWe3dnyknHVkNVaqufBmLFrqNjyLnIisNqryvNgrGIOHsgvCzVZXjK+NGRd0+75XM1K5dQPQF/Dvaa+GdnyKTjvvPNYuHAhmzdvZmBggGuuuWZK44HPCKxLRur3R0o3R+r3Af/r3HrDsedk32+7JLscdPBAlgRGlk/S8uXL2xDc/pwIrCtSzOdqVjrHnjPlP/yd4EtD1hWu3zcrDycC6wrX71tVRUS3QxjXZOJzIrCucP2+VdH06dPZsWNHaZNBRLBjxw6m
T59e6HO+R2Bd4fp9q6KBgQGGhobYvn17t0MZ0/Tp0xkYKFaiqtSZTdI0YDWwNSKWNLwn4B+AxcBjwPkRsXa88QYHB2P16tWpwjUz60mS1kTEYLP3OnFGcAGwCTioyXuvBY7Kv04Crsy/m9WO50Wwbkl6j0DSAHA68NkxVjkT+GJk7gJmSjosZUxmZeR5EaybUt8s/iTwXuA3Y7w/Cxjdgz2ULzOrFc+LYN2ULBFIWgL8IiLWjLdak2VPuWkhaZmk1ZJWl/kmjdlkua/CuinlGcEi4AxJPwWuA06RdG3DOkPA7FGvB4BtjQNFxNURMRgRg/39/aniNesa91VYNyVLBBFxcUQMRMRc4Fzg9oh4Q8NqXwPepMxLgV0R8VCqmMzKyn0V1k0d7yOQ9DaAiLgKuImsdPQBsvLRN3c6HrMycF+FdVPyPoJ2cx+BmVlx3e4jMOuo96/cwPJVWxiOYJrEeSfN5q+XHtPtsMxKy4nAesr7V27g2rv+Y+/r4Yi9r50MzJrzQ+espyxf1WRqwHGWm5kTgfWY4THueY213MycCKzHTFOzHsWxl5uZE4H1mPNOml1ouZn5ZrH1mJEbwq4aMmud+wjMzGpgvD4CXxoyM6s5Xxqytvrjz/wr3//JL/e+XvSCZ/OlP1nYxYi6xxPNWFX4jMDapjEJAHz/J7/kjz/zr12KqHs80YxViROBtU1jEphoeS/zRDNWJU4EZgl4ohmrEicCswQ80YxViROBtc2iFzy70PJe5olmrEqcCKxtvvQnC5/yR7+uVUNLF8zi0rOOYdbMGQiYNXMGl551jKuGrJTcUGZmVgOemMY6JlXtfJFxXb9vVowTgbXNSO38SNnkSO08MKU/xEXGTRWDWS/zPQJrm1S180XGdf2+WXFOBNY2qWrni4zr+n2z4pwIrG1S1c4XGdf1+2bFORFY26SqnS8yruv3zYrzzWJrm5Gbse2u2CkybqoYzHqZ+wjMzGrAfQQlUZb6dtfkm9loTgQdUpb6dtfkm1kj3yzukLLUt7sm38waORF0SFnq212Tb2aNnAg6pCz17a7JN7NGTgQdUpb6dtfkm1kj3yzukLLUt7sm38wauY/AzKwGutJHIGk68F3g6fnv+UpEfLBhnVcCNwAP5ouuj4hLUsVkxb1/5QaWr9rCcATTJM47aTZ/vfSYtqxflh6FssRh1i0TJgJJTwdeD8wdvX4Lf7B/DZwSEY9K6gO+J+nmiLirYb07I2JJsbCtE96/cgPX3vUfe18PR+x93eyPe5H1y9KjUJY4zLqplZvFNwBnAk8Cvxr1Na7IPJq/7Mu/qnUdquaWr9qSbHlZehTKEodZN7VyaWggIl4zmcElTQPWAEcC/zMiVjVZbaGke4FtwHsi4r4m4ywDlgHMmTNnMqHYJAyPcf+oHcvL0qNQljjMuqmVM4IfSBr7ovA4ImI4Io4DBoATJc1vWGUtcEREvBj4FLByjHGujojBiBjs7++fTCg2CdOkZMvL0qNQljjMumnMRCBpg6T1wMuBtZI2S1o/annLImIncAfwmoblD49cPoqIm4A+SYcW3AZL5LyTZidbXpYehbLEYdZN410amtINXEn9wJ6I2ClpBvBq4KMN6zwX+HlEhKQTyRLTjqn8XmufkRu8rVYBFVm/LD0KZYnDrJsm7COQ9E8R8caJljX53LHAF4BpZH/gV0TEJZLeBhARV0n6M+DtZDeidwMXRsQPxhvXfQRmZsVNtY/g6IbBpgEnTPShiFgPLGiy/KpRP18BXNFCDGZmlsiYiUDSxcD7gBmSHh5ZDDwBXN2B2HpOysaloo1fqcYtw6Q3qfZFZa1fAbddAruG4OABOPUDcOw53Y7KSmTMRBARlwKXSro0Ii7uYEw9KWXjUtHGr1TjlmHSm1T7orLWr4Cvvwv25OWwu7Zkr8HJwPYar2roeEnHA18e+Xn0Vwdj7AkpG5eKNnilGrcMk96k2heVddsl+5LAiD27s+VmufHuEfx9/n06MAjc
S3Zp6FhgFVlZqbUoZeNS0QavVOOWYdKbVPuisnYNFVtutTTmGUFEvCoiXgX8DDg+b+g6gewG8AOdCrBXpGxcKtrglWrcMkx6k2pfVNbBA8WWWy210ln8XyJiw8iLiPghcFyyiHpUysalog1eqcYtw6Q3qfZFZZ36AehrSK59M7LlZrlWykc3SfoscC3ZQ+PeAGxKGlUPStm4VLTxK9W4ZZj0JtW+qKyRG8KuGrJxtNJQNp2s6evkfNF3gSsj4vHEsTXlhjIzs+Km1FCW/8H/RP5lNVO01t+TvNiY3M9QWuM1lK2IiHMkbaDJPAIRcWzSyKzritb6e5IXG5P7GUptvJvFF+TflwCva/JlPa5orb8nebExuZ+h1MbrLH4o//FUsukk7+9MSFYWRWv9PcmLjcn9DKXWSvnoXODTkn4iaYWkd0o6Lm1YVgZFa/09yYuNyf0MpTZhIoiID0TEKcB84HvARWTTT1qPK1rr70lebEzuZyi1CauGJL0fWAQ8C1gHvAe4M3FcVgJFa/09yYuNyf0MpdZKH8FasoljvgF8B7irWz0E4D4CM7PJmGofwfGSDiR7yNzvA5+R9POI6NmHzqWqhS86bhmeq+++gJLq9Zr8Xt++ohLvj1YuDc0Hfg94BdlTSLfQw5eGUtXCFx23DM/Vd19ASfV6TX6vb19RHdgfrVQNfRQ4ELgceGH+VNKevcOTqha+6LhleK6++wJKqtdr8nt9+4rqwP5o5dLQ6W37bRWQqha+6LhleK6++wJKqtdr8nt9+4rqwP5o5YygVlLVwhcdtwzP1XdfQEn1ek1+r29fUR3YH04EDVLVwhcdtwzP1XdfQEn1ek1+r29fUR3YH63MR1ArqWrhi45bhufquy+gpHq9Jr/Xt6+oDuyPMfsIJH2dJk8dHRERZ7QtigLcR2BmVtxk+wg+liie2kpZk19k7DL0J5hVwo0XwprPQwyDpsEJ58OSj7dn7BL1Soz39NHvdDKQXpeyJr/I2GXoTzCrhBsvhNXX7Hsdw/teTzUZlKxXYsKbxZKOkvQVSRsl/fvIVyeC6yUpa/KLjF2G/gSzSljz+WLLiyhZr0QrVUP/CFxJ9ryhVwFfBP4pZVC9KGVNfpGxy9CfYFYJMVxseREl65VoJRHMiIjbyG4s/ywiPgSckjas3pOyJr/I2GXoTzCrBE0rtryIkvVKtJIIHpf0W8D9kv5M0h8Cz0kcV89JWZNfZOwy9CeYVcIJ5xdbXkTJeiVa6SN4N/AM4F3Ah8nOBv5bwph6Usqa/CJjl6E/wawSRm4Ip6gaKlmvxITzEexdUToIiIh4JG1I43MfgZlZceP1EbRSNTQoaQOwHtgg6V5JJ7TwuemS7s7Xv0/S/2iyjiRdLukBSeslHd/KBpmZWfu0cmnoc8B/j4g7ASS9nKyS6NgJPvdr4JSIeFRSH/A9STdHxF2j1nktcFT+dRJZddJJBbdhQkUbuao4GUuRJrEi21fFfZG0UadIg1HKOFKNXaImp2SKbGMd9getJYJHRpIAQER8T9KEl4ciu+b0aP6yL/9qvA51JvDFfN27JM2UdFhEPNRa+BMr2shVxclYijSJFdm+Ku6LpI06RRqMUsaRauySNTklUWQb67A/cq1UDd0t6dOSXinpFZL+F3CHpOMnupQjaZqke4BfAN+KiFUNq8wim/FsxFC+rG2KNnJVcTKWIk1iRbavivsiaaNOkQajlHGkGrtkTU5JFNnGOuyPXCtnBMfl3z/YsPxlZP/CH7OnICKGgeMkzQS+Kml+RPxw1CrNitefcvda0jJgGcCcOXNaCHmfoo1cVZyMpUiTWJHtq+K+SNqoU6TBKGUcqcYuWZNTEkW2sQ77IzfhGUE+NeVYXy01lkXETuAO4DUNbw0BowvYB4BtTT5/dUQMRsRgf39/K79yr6KNXFWcjKVIk1iR7avivkjaqFOkwShlHKnGLlmTUxJFtrEO+yPXStXQ70i6RtLN+esXSXprC5/rz88EkDQD
eDXwo4bVvga8Ka8eeimwq533B6B4I1cVJ2Mp0iRWZPuquC+SNuoUaTBKGUeqsUvW5JREkW2sw/7ItXJp6PNkVUJ/mb/+MfDPwDVjfSB3GPAFSdPIEs6KiLhR0tsAIuIq4CZgMfAA8Bjw5qIbMJGijVxVnIylSJNYke2r4r5I2qhTpMEoZRypxi5Zk1MSRbaxDvsjN2FDmaR/i4iXSFoXEQvyZfdExHGdCLCRG8rMzIqb7MQ0I34l6RDym7gjl3DaGF/pVLJ23jqjijXoKWOuYj9DWY5LibSSCC4ku5b/AknfB/qBs5NG1UWVrJ23zqhiDXrKmKvYz1CW41IyrVQNrQVeQVYu+qfA0RGxPnVg3VLJ2nnrjCrWoKeMuYr9DGU5LiXTStXQH5HNSXAfsBT4515+JlAla+etM6pYg54y5ir2M5TluJRMK53FfxURj+TPGDoN+ALZM4F6UiVr560zqliDnjLmKvYzlOW4lEwriWDkOsnpwJURcQPwtHQhdVcla+etM6pYg54y5ir2M5TluJRMK4lgq6RPA+cAN0l6eoufq6SlC2Zx6VnHMGvmDATMmjmDS886xjeKLbuZ+LrL4eDZgLLvr7t87Br0VtetasyptjHlvivLcSmZVvoInkH2aIgNEXG/pMOAYyLim50IsJH7CMzMiptSH0FEPAZcP+r1Q0BbHwNh1pOKzF1QFlWMuSx9AWWJYxJa6SMws6KKzF1QFlWMuSx9AWWJY5J69lq/WVcVmbugLKoYc1n6AsoSxyQ5EZilUGTugrKoYsxl6QsoSxyT5ERglkKRuQvKoooxl6UvoCxxTJITgVkKReYuKIsqxlyWvoCyxDFJTgRmKSz5OAy+dd+/pjUte13Wm65QzZjL0hdQljgmacI+grJxH4GZWXFTnY/ALI0q1l2njDlVDX8V97N1lBOBdUcV665Txpyqhr+K+9k6zvcIrDuqWHedMuZUNfxV3M/WcU4E1h1VrLtOGXOqGv4q7mfrOCcC644q1l2njDlVDX8V97N1nBOBdUcV665Txpyqhr+K+9k6zonAuqOKddcpY05Vw1/F/Wwd5z4CM7MaGK+PwGcEZutXwCfmw4dmZt/Xr+jOuKniMJuA+wis3lLV2Rcd1/X+1kU+I7B6S1VnX3Rc1/tbFzkRWL2lqrMvOq7r/a2LnAis3lLV2Rcd1/X+1kVOBFZvqersi47ren/rIicCq7dUdfZFx3W9v3WR+wjMzGqgK30EkmZL+rakTZLuk3RBk3VeKWmXpHvyL58Hm5l1WMo+gieBP4+ItZIOBNZI+lZEbGxY786IWJIwDuukKk6CUiTmKm5fWXjflVayRBARDwEP5T8/ImkTMAtoTATWK6rYFFUk5ipuX1l435VaR24WS5oLLABWNXl7oaR7Jd0s6ehOxGOJVLEpqkjMVdy+svC+K7Xkj5iQ9CzgX4B3R8TDDW+vBY6IiEclLQZWAkc1GWMZsAxgzpw5aQO2yatiU1SRmKu4fWXhfVdqSc8IJPWRJYEvRcT1je9HxMMR8Wj+801An6RDm6x3dUQMRsRgf39/ypBtKqrYFFUk5ipuX1l435VayqohAdcAmyKi6UPVJT03Xw9JJ+bx7EgVkyVWxaaoIjFXcfvKwvuu1FJeGloEvBHYIOmefNn7gDkAEXEVcDbwdklPAruBc6NqjQ22z8hNvypVhhSJuYrbVxbed6XmhjIzsxoYr6HM8xHUkeu593fjhbDm8xDD2RSRJ5w/9SkizSrEiaBuXM+9vxsvhNXX7Hsdw/teOxlYTfihc3Xjeu79rfl8seVmPciJoG5cz72/GC623KwHORHUjeu596dpxZab9SAngrpxPff+Tji/2HKzHuREUDeeAGV/Sz4Og2/ddwagadlr3yi2GnEfgZlZDbiPIKGV67Zy2a2b2bZzN4fPnMFFp81j6YJZ3Q6rferQc1CHbSwD7+fSciKYgpXrtnLx9RvYvSerMNm6czcXX78BoDeSQR16DuqwjWXg
/VxqvkcwBZfdunlvEhixe88wl926uUsRtVkdeg7qsI1l4P1cak4EU7Bt5+5CyyunDj0HddjGMvB+LjUngik4fOaMQssrpw49B3XYxjLwfi41J4IpuOi0eczo27/xaEbfNC46bV6XImqzOvQc1GEby8D7udR8s3gKRm4I92zVUB2eIV+HbSwD7+dScx+BmVkNjNdH4EtDZr1u/Qr4xHz40Mzs+/oV1RjbOsaXhsx6Wcr6ffcG9AyfEZj1spT1++4N6BlOBGa9LGX9vnsDeoYTgVkvS1m/796AnuFEYNbLUtbvuzegZzgRmPWylPNPeG6LnuE+AjOzGnAfgZmZjcmJwMys5pwIzMxqzonAzKzmnAjMzGrOicDMrOacCMzMas6JwMys5pIlAkmzJX1b0iZJ90m6oMk6knS5pAckrZd0fKp4zMysuZRnBE8Cfx4RLwReCrxD0osa1nktcFT+tQy4MmE8NhmeeMSs5yVLBBHxUESszX9+BNgENE7meybwxcjcBcyUdFiqmKygkYlHdm0BYt/EI04GZj2lI/cIJM0FFgCrGt6aBWwZ9XqIpyYL6xZPPGJWC8kTgaRnAf8CvDsiHm58u8lHnvIUPEnLJK2WtHr79u0pwrRmPPGIWS0kTQSS+siSwJci4vomqwwBs0e9HgC2Na4UEVdHxGBEDPb396cJ1p7KE4+Y1ULKqiEB1wCbIuLjY6z2NeBNefXQS4FdEfFQqpisIE88YlYLByQcexHwRmCDpHvyZe8D5gBExFXATcBi4AHgMeDNCeOxokYmGLntkuxy0MEDWRLwxCNmPcUT05iZ1YAnpjEzszE5EZiZ1ZwTgZlZzTkRmJnVnBOBmVnNVa5qSNJ24GfdjqOJQ4H/7HYQCfX69kHvb6O3r/qmso1HRETTjtzKJYKykrR6rNKsXtDr2we9v43evupLtY2+NGRmVnNOBGZmNedE0D5XdzuAxHp9+6D3t9HbV31JttH3CMzMas5nBGZmNedEUJCkaZLWSbqxyXuvlLRL0j35V+We1yzpp5I25PE/5el++SPDL5f0gKT1ko7vRpxT0cI2Vvo4Spop6SuSfiRpk6SFDe9X+hi2sH1VP37zRsV+j6SHJb27YZ22HsOUj6HuVReQzb980Bjv3xkRSzoYTwqvioixapVfCxyVf50EXJl/r5rxthGqfRz/AbglIs6W9DTgGQ3vV/0YTrR9UOHjFxGbgeMg+4cnsBX4asNqbT2GPiMoQNIAcDrw2W7H0kVnAl+MzF3ATEmHdTsoy0g6CDiZbFIoIuKJiNjZsFplj2GL29dLTgV+EhGNTbRtPYZOBMV8Engv8Jtx1lko6V5JN0s6ujNhtVUA35S0RtKyJu/PAraMej2UL6uSibYRqnscnw9sB/4xv4T5WUnPbFinysewle2D6h6/RucCy5ssb+sxdCJokaQlwC8iYs04q60la+N+MfApYGUnYmuzRRFxPNmp5zskndzwvpp8pmqlZxNtY5WP4wHA8cCVEbEA+BXwFw3rVPkYtrJ9VT5+e+WXvc4Avtzs7SbLJn0MnQhatwg4Q9JPgeuAUyRdO3qFiHg4Ih7Nf74J6JN0aMcjnYKI2JZ//wXZdckTG1YZAmaPej0AbOtMdO0x0TZW/DgOAUMRsSp//RWyP5yN61T1GE64fRU/fqO9FlgbET9v8l5bj6ETQYsi4uKIGIiIuWSna7dHxBtGryPpuZKU/3wi2f7d0fFgJ0nSMyUdOPIz8AfADxtW+xrwprxq4aXAroh4qMOhTlor21jl4xgR/xfYImlevuhUYGPDapU9hq1sX5WPX4PzaH5ZCNp8DF01NEWS3gYQEVcBZwNvl/QksBs4N6rVsfc7wFfz/4cOAP53RNzSsI03AYuBB4DHgDd3KdbJamUbq34c3wl8Kb+08O/Am3vsGE60fVU/fkh6BvD7wJ+OWpbsGLqz2Mys5nxpyMys5pwIzMxqzonAzKzmnAjMzGrOicDMrOacCMwKyp9uOdbT
Z5+yvA2/b6mkF416fYeknp6b1zrLicCs/JYCL5poJbPJciKwnpN3D38jf+jYDyX913z5CZK+kz9s7taRpzXm/8L+pKQf5OufmC8/MV+2Lv8+b7zf2ySGz0n6t/zzZ+bLz5d0vaRbJN0v6e9Gfeatkn6cx/MZSVdIehnZ82YuU/Zs+hfkq/+RpLvz9X+vTbvOasqdxdaLXgNsi4jTASQdLKmP7AFkZ0bE9jw5/A3wlvwzz4yIl+UPoPscMB/4EXByRDwp6dXA3wKvbzGGvyR7DMlbJM0E7pb0f/L3jgMWAL8GNkv6FDAM/BXZc3MeAW4H7o2IH0j6GnBjRHwl3x6AAyLiREmLgQ8Cry6+m8wyTgTWizYAH5P0UbI/oHdKmk/2x/1b+R/SacDoZ7MsB4iI70o6KP/jfSDwBUlHkT3Zsa9ADH9A9pDC9+SvpwNz8p9vi4hdAJI2AkcAhwLfiYhf5su/DPzuOONfn39fA8wtEJfZUzgRWM+JiB9LOoHsWSyXSvom2VNG74uIhWN9rMnrDwPfjog/lDQXuKNAGAJen882tW+hdBLZmcCIYbL/D5s9Vng8I2OMfN5s0nyPwHqOpMOBxyLiWuBjZJdbNgP9yue3ldSn/ScsGbmP8HKyJznuAg4mmyYQ4PyCYdwKvHPUUzAXTLD+3cArJP22pAPY/xLUI2RnJ2ZJ+F8S1ouOIbu5+htgD/D2iHhC0tnA5ZIOJvtv/5PAffln/p+kH5DNRT1y3+DvyC4NXUh2zb6ID+fjr8+TwU+BMefQjYitkv4WWEX2XPmNwK787euAz0h6F9mTNc3ayk8ftdqTdAfwnohY3eU4nhURj+ZnBF8FPhcRjZOWm7WdLw2ZlceHJN1DNlHOg1R0ikWrHp8RmJnVnM8IzMxqzonAzKzmnAjMzGrOicDMrOacCMzMas6JwMys5v4/HDGcEogo/iEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter plot of the first two classes (labels 0 and 1), using only sepal length and sepal width\n",
    "plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')\n",
    "plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')\n",
    "plt.xlabel('sepal length')\n",
    "plt.ylabel('sepal width')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Split the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the first 100 rows, using only the first two feature columns and the last column (the label)\n",
    "data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "X, y = data[:, :-1], data[:, -1]\n",
    "\n",
    "# Randomly split the samples into a training set (80%) and a test set (20%)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Implement KNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KNN:\n",
    "    def __init__(self, X_train, y_train, n_neighbors=3, p=2):\n",
    "        \"\"\"\n",
    "        parameter: n_neighbors  number of neighbors, i.e. the value of k\n",
    "        parameter: p            order of the L_p distance metric (p=2 is Euclidean)\n",
    "        \"\"\"\n",
    "        self.n = n_neighbors\n",
    "        self.p = p\n",
    "        self.X_train = X_train\n",
    "        self.y_train = y_train\n",
    "\n",
    "    # Predict the class of a point X whose label is unknown\n",
    "    def predict(self, X):\n",
    "        # Start with the first n training points and compute the L_p distance from X to each\n",
    "        knn_list = []\n",
    "        for i in range(self.n):\n",
    "            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)\n",
    "            knn_list.append((dist, self.y_train[i]))\n",
    "\n",
    "        # For each remaining training point, compute its distance to X. If it is smaller than the\n",
    "        # largest distance currently in knn_list, replace that entry. Afterwards knn_list holds\n",
    "        # the n training points closest to X.\n",
    "        for i in range(self.n, len(self.X_train)):\n",
    "            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))\n",
    "            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)\n",
    "            if knn_list[max_index][0] > dist:\n",
    "                knn_list[max_index] = (dist, self.y_train[i])\n",
    "\n",
    "        # Classification decision (majority vote): the most common class among the n neighbors\n",
    "        knn = [k[-1] for k in knn_list]\n",
    "        majority_label = Counter(knn).most_common(1)[0][0]\n",
    "        return majority_label\n",
    "\n",
    "    # Score: prediction accuracy of the model on the test set\n",
    "    def score(self, X_test, y_test):\n",
    "        right_count = 0\n",
    "        for X, y in zip(X_test, y_test):\n",
    "            label = self.predict(X)\n",
    "            if label == y:\n",
    "                right_count += 1\n",
    "        return right_count / len(X_test)"
   ]
  },
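  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, `predict` computes one $L_p$ distance per training point. For every point after the first $k$, it also scans the current list of $k$ candidates for the largest distance. A single query therefore costs on the order of $N n + (N-k)k$ operations, for $N$ training points with $n$ features each. `score` returns the accuracy on the test set,\n",
    "\n",
    "$$\\text{accuracy}=\\frac{1}{N_{\\text{test}}} \\sum_{i=1}^{N_{\\text{test}}} I\\left(\\hat{y}_{i}=y_{i}\\right)$$\n",
    "\n",
    "where $\\hat{y}_i$ is the class returned by `predict` for the $i$-th test point. The kd tree described at the start of the chapter is designed to avoid this linear scan over the whole training set."
   ]
  },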
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Create a KNN instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = KNN(X_train, y_train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Model accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 100.00%\n"
     ]
    }
   ],
   "source": [
    "print('Accuracy: {:.2%}'.format(clf.score(X_test, y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Predict the class of a new test point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class of test point X: 1.0\n"
     ]
    }
   ],
   "source": [
    "test_point = [6.0, 3.0]\n",
    "print('Predicted class of test point X: {}'.format(clf.predict(test_point)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x1d463673250>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEJCAYAAACZjSCSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAji0lEQVR4nO3de5xVdb3/8dfHcRJQghI6CQNOXn78SkAu4y285KXwgklkpklG+jiEx5OWSUV6NC2PevSXHvSXhmWZIkoewjJESyUtjxg3QUVD88LFXxIFgqAHx8/vj7UGhs2emb1m9nfvtfZ6Px+PeezZa6+95vNdW/eHtdbns77m7oiISH7tUu0ARESkupQIRERyTolARCTnlAhERHJOiUBEJOeUCEREci54IjCzOjNbbGb3F3ntE2a2wcyWxD+Xho5HRER2tGsF/sYFwHLg/W28/ri7j6lAHCIiUkTQRGBmDcBJwJXAheXYZp8+fbyxsbEcmxIRyY2FCxf+zd37Fnst9BHBDcA3gZ7trHOYmT0NrAEucvdn29tgY2MjCxYsKF+EIiI5YGavtvVasGsEZjYGeMPdF7az2iJgb3c/ELgRmN3Gtiaa2QIzW7B27dryBysikmMhLxaPAj5tZq8AdwPHmNmdrVdw9zfdfVP8+xyg3sz6FG7I3ae5e5O7N/XtW/TIRkREOilYInD3Ke7e4O6NwOnAI+4+vvU6ZvZhM7P494PjeNaFiklERHZWiaqhHZjZJAB3vwU4FTjXzN4FtgCnu26HKpIrW7duZdWqVbz99tvVDqUmdOvWjYaGBurr60t+j2Xte7epqcl1sVikdrz88sv07NmTPffck/gEgXSSu7Nu3To2btzIRz7ykR1eM7OF7t5U7H0VPyIQyYvZi1dz7YMvsGb9Fvr17s7k0YMYO7x/tcNKnbfffpvGxkYlgTIwM/bcc0+SFtUoEYgEMHvxaqbMWsaWrc0ArF6/hSmzlgEoGRShJFA+ndmXuteQSADXPvjCtiTQYsvWZq598IUqRSTSNiUCkQDWrN+SaLlUz/r16/nhD3/YqffecMMNbN68uazxXHrppfzud79rd5158+bxxBNPlO1vKhGIBNCvd/dEy6V006dDYyPsskv0OH1617aXtkRwxRVXcNxxx7W7jhKBSAZMHj2I7vV1OyzrXl/H5NGDqhRRbZg+HSZOhFdfBffoceLEriWDb3/727z00ksMGzaMyZMnc+2113LQQQcxdOhQLrvsMgDeeustTjrpJA488EAGDx7MPffcw9SpU1mzZg1HH300Rx99dJvb32OPPfjGN77BiBEjOPbYY7ddyF2yZAmHHnooQ4cO5TOf+Qz/+Mc/AJgwYQL33nsvEN1S57LLLmPEiBEMGTKE559/nldeeYVbbrmF66+/nmHDhvH44493fvAxJQKRAMYO789V44bQv3d3DOjfuztXjRuiC8VddPHFUPgP8M2bo+WddfXVV7PvvvuyZMkSPvnJT7JixQqeeuoplixZwsKFC3nssceYO3cu/fr14+mnn+aZZ57h+OOP5/zzz6dfv348+uijPProo21u/6233mLEiBEsWrSIo446issvvxyAs846i2uuuYalS5cyZMiQbcsL9enTh0WLFnHuuedy3XXX0djYyKRJk/j617/OkiVLOOKIIzo/+JiqhkQCGTu8v774y+y115ItT+qhhx7ioYceYvjw4QBs2rSJFStWcMQRR3DRRRfxrW99izFjxiT68t1ll134/Oc/D8D48eMZN24cGzZsYP369Rx11FEAfOlLX+Jzn/tc0fePGzcOgJEjRzJr1qyuDK9NSgQikhkDB0ang4otLwd3Z8qUKXzlK1/Z6bWFCxcyZ84cpkyZwqc+9SkuvbRz82glLe/cbbfdAKirq+Pdd9/t1N/siE4NiUhmXHkl9Oix47IePaLlndWzZ082btwIwOjRo7ntttvYtGkTAKtXr+aNN95gzZo19OjR
g/Hjx3PRRRexaNGind7blvfee2/bOf+77rqLww8/nF69evGBD3xg2/n9O+64Y9vRQdKYy0FHBCKSGWeeGT1efHF0OmjgwCgJtCzvjD333JNRo0YxePBgTjjhBL7whS9w2GGHAdGF3jvvvJMXX3yRyZMns8suu1BfX8/NN98MwMSJEznhhBPYa6+92rxOsPvuu/Pss88ycuRIevXqxT333APA7bffzqRJk9i8eTP77LMPP/3pT0uO+eSTT+bUU0/lvvvu48Ybb+zydQLda0hEqmr58uV89KMfrXYYweyxxx7bjjAqpdg+be9eQzo1JCKSczo1JCJSBocccgjvvPPODsvuuOOOih8NdIYSgYhIGcyfP7/aIXSaTg2JiOScEoGISM7p1JDkniaQkbzTEYHkWssEMqvXb8HZPoHM7MWrqx2aVNDcuXMZNGgQ++23H1dffXW1w6k4JQLJNU0gI83NzZx33nk88MADPPfcc8yYMYPnnnuu2mFVlBKB5JomkJGnnnqK/fbbj3322Yf3ve99nH766dx3333VDquidI1Acq1f7+6sLvKlrwlk0qvc13RWr17NgAEDtj1vaGjIdCloZ+iIQHJNE8hkS4hrOsVus9OZCeCzTIlAck0TyGRLiGs6DQ0NrFy5ctvzVatW0a9fv05vL4t0akhyTxPIZEeIazoHHXQQK1as4OWXX6Z///7cfffd3HXXXZ3eXhYpEUjVqH5fkgpxTWfXXXflpptuYvTo0TQ3N3P22WdzwAEHdCXMzFEikKpoOdfbcpjfcq4XUDKQNk0ePWiH/26gPNd0TjzxRE488cSuhpdZukYgVaH6fekMXdMJQ0cEUhWq35fO0jWd8tMRgVRFW+d0Vb8vUnlKBFIVqt8XSQ+dGpKqaDm0V9WQSPUFTwRmVgcsAFa7+5iC1wz4T+BEYDMwwd0XhY5J0kHnekXSoRKnhi4Alrfx2gnA/vHPRODmCsQjkkqzF69m1NWP8JFv/4ZRVz+iW2FX0Nlnn82HPvQhBg8eXO1QqiJoIjCzBuAk4MdtrHIK8HOPPAn0NrO9QsYkkkaaF6G6JkyYwNy5c6sdRtWEPiK4Afgm8F4br/cHVrZ6vipeJpIr6quoriOPPJIPfvCD1Q6jaoIlAjMbA7zh7gvbW63Isp1uBWhmE81sgZktWLt2bdliFEkL9VUksHQmXD8Yvts7elw6s9oRZV7II4JRwKfN7BXgbuAYM7uzYJ1VwIBWzxuANYUbcvdp7t7k7k19+/YNFa9I1aivokRLZ8Kvz4cNKwGPHn99vpJBFwVLBO4+xd0b3L0ROB14xN3HF6z2K+AsixwKbHD310PFJJJW6qso0cNXwNaCo6StW6Ll0mkV7yMws0kA7n4LMIeodPRFovLRL1c6HpE0UF9FiTasSrZcSlKRRODu84B58e+3tFruwHmViEEk7dRXUYJeDfFpoSLLu+CMM85g3rx5/O1vf6OhoYHLL7+cc845p0vbzBJ1FkvNuWT2MmbMX0mzO3VmnHHIAL4/dki1w5JyOPbS6JpA69ND9d2j5V0wY8aMLgaWbbrXkNSUS2Yv484nX6M5noe22Z07n3yNS2Yvq3JkUhZDT4OTp0KvAYBFjydPjZZLp+mIQGrKjPlFThvEy3VUUCOGnqYv/jLTEYHUlJYjgVKXi4gSgdSYOivWo9j2ckkHV6Ium87sSyUCqSlnHDIg0XKpvm7durFu3TolgzJwd9atW0e3bt0SvU/XCKSmtFwHUNVQdjQ0NLBq1Sp0+5jy6NatGw0NycppLWtZuKmpyRcsWFDtMEREMsXMFrp7U7HXdGpIRCTndGpIyurMW/+bP770923PR+37Qab/82FVjKh6Zi9erVtGSCboiEDKpjAJAPzxpb9z5q3/XaWIqkcTzUiWKBFI2RQmgY6W1zJNNCNZokQgEoAmmpEsUSIQCUATzUiWKBFI2Yzat/icr20tr2WaaEayRIlAymb6Px+205d+XquGxg7vz1XjhtC/
d3cM6N+7O1eNG6KqIUklNZSJiORAew1l6iOQsgpVO59ku6rfF0lGiUDKpqV2vqVssqV2HujSF3GS7YaKQaSW6RqBlE2o2vkk21X9vkhySgRSNqFq55NsV/X7IskpEUjZhKqdT7Jd1e+LJKdEIGUTqnY+yXZVvy+SnC4WS9m0XIwtd8VOku2GikGklqmPQEQkB9RHkBJpqW9XTb6ItKZEUCFpqW9XTb6IFNLF4gpJS327avJFpJASQYWkpb5dNfkiUkiJoELSUt+umnwRKaREUCFpqW9XTb6IFNLF4gpJS327avJFpJD6CEREcqAqfQRm1g14DNgt/jv3uvtlBet8ArgPeDleNMvdrwgVkyR3yexlzJi/kmZ36sw445ABfH/skLKsn5YehbTEIVItHSYCM9sN+CzQ2Hr9Er6w3wGOcfdNZlYP/MHMHnD3JwvWe9zdxyQLWyrhktnLuPPJ17Y9b3bf9rzYl3uS9dPSo5CWOESqqZSLxfcBpwDvAm+1+mmXRzbFT+vjn2ydh8q5GfNXBluelh6FtMQhUk2lnBpqcPfjO7NxM6sDFgL7Af/X3ecXWe0wM3saWANc5O7PFtnORGAiwMCBAzsTinRCcxvXj8qxPC09CmmJQ6SaSjkieMLM2j4p3A53b3b3YUADcLCZDS5YZRGwt7sfCNwIzG5jO9Pcvcndm/r27duZUKQT6syCLU9Lj0Ja4hCppjYTgZktM7OlwOHAIjN7wcyWtlpeMndfD8wDji9Y/mbL6SN3nwPUm1mfhGOQQM44ZECw5WnpUUhLHCLV1N6poS5dwDWzvsBWd19vZt2B44BrCtb5MPBXd3czO5goMa3ryt+V8mm5wFtqFVCS9dPSo5CWOESqqcM+AjO7w92/2NGyIu8bCtwO1BF9wc909yvMbBKAu99iZv8KnEt0IXoLcKG7P9HedtVHICKSXFf7CA4o2FgdMLKjN7n7UmB4keW3tPr9JuCmEmIQEZFA2kwEZjYF+A7Q3czebFkM/A8wrQKx1ZyQjUtJG79CbTcNk96E2heZtXQmPHwFbFgFvRrg2Eth6GnVjkpSpM1E4O5XAVeZ2VXuPqWCMdWkkI1LSRu/Qm03DZPehNoXmbV0Jvz6fNgal8NuWBk9ByUD2aa9qqERZjYC+EXL761/KhhjTQjZuJS0wSvUdtMw6U2ofZFZD1+xPQm02LolWi4Sa+8awf+JH7sBTcDTRKeGhgLzicpKpUQhG5eSNniF2m4aJr0JtS8ya8OqZMsll9o8InD3o939aOBVYETc0DWS6ALwi5UKsFaEbFxK2uAVartpmPQm1L7IrF4NyZZLLpXSWfy/3X1ZyxN3fwYYFiyiGhWycSlpg1eo7aZh0ptQ+yKzjr0U6guSa333aLlIrJTy0eVm9mPgTqKbxo0HlgeNqgaFbFxK2vgVartpmPQm1L7IrJYLwqoaknaU0lDWjajp68h40WPAze7+duDYilJDmYhIcl1qKIu/8K+PfyRnktb6a5IXaZP6GVKrvYayme5+mpkto8g8Au4+NGhkUnVJa/01yYu0Sf0MqdbexeIL4scxwMlFfqTGJa311yQv0ib1M6Rae53Fr8e/Hks0neSKyoQkaZG01l+TvEib1M+QaqWUjzYCPzKzl8xsppl91cyGhQ1L0iBprb8meZE2qZ8h1TpMBO5+qbsfAwwG/gBMJpp+Umpc0lp/TfIibVI/Q6p1WDVkZpcAo4A9gMXARcDjgeOSFEha669JXqRN6mdItVL6CBYRTRzzG+D3wJPV6iEA9RGIiHRGV/sIRphZT6KbzH0SuNXM/uruNXvTuVC18Em3m4b76qsvIKVqvSa/1seXVOD9UcqpocHAEcBRRHchXUkNnxoKVQufdLtpuK+++gJSqtZr8mt9fElVYH+UUjV0DdATmAp8NL4rac1e4QlVC590u2m4r776AlKq1mvya318SVVgf5Ryauiksv21DAhVC590u2m4r776AlKq1mvya318SVVgf5Ry
RJAroWrhk243DffVV19AStV6TX6tjy+pCuwPJYICoWrhk243DffVV19AStV6TX6tjy+pCuyPUuYjyJVQtfBJt5uG++qrLyClar0mv9bHl1QF9kebfQRm9muK3HW0hbt/umxRJKA+AhGR5DrbR3BdoHhyK2RNfpJtp6E/QSQT7r8QFv4MvBmsDkZOgDE/KM+2U9Qr0d7dR39fyUBqXcia/CTbTkN/gkgm3H8hLPjJ9ufevP15V5NBynolOrxYbGb7m9m9Zvacmf2l5acSwdWSkDX5Sbadhv4EkUxY+LNky5NIWa9EKVVDPwVuJrrf0NHAz4E7QgZVi0LW5CfZdhr6E0QywZuTLU8iZb0SpSSC7u7+MNGF5Vfd/bvAMWHDqj0ha/KTbDsN/QkimWB1yZYnkbJeiVISwdtmtguwwsz+1cw+A3wocFw1J2RNfpJtp6E/QSQTRk5ItjyJlPVKlNJH8DWgB3A+8D2io4EvBYypJoWsyU+y7TT0J4hkQssF4RBVQynrlehwPoJtK5q9H3B33xg2pPapj0BEJLn2+ghKqRpqMrNlwFJgmZk9bWYjS3hfNzN7Kl7/WTO7vMg6ZmZTzexFM1tqZiNKGZCIiJRPKaeGbgP+xd0fBzCzw4kqiYZ28L53gGPcfZOZ1QN/MLMH3P3JVuucAOwf/xxCVJ10SMIxdChpI1cWJ2NJ0iSWZHxZ3BdBG3WSNBiFjCPBtqdPh4svhtdeg4ED4cor4cwzu77dzEoyxjzsD0pLBBtbkgCAu//BzDo8PeTROadN8dP6+KfwPNQpwM/jdZ80s95mtpe7v15a+B1L2siVxclYkjSJJRlfFvdF0EadJA1GIeNIsO3p02HiRNi8OXr+6qvRcyiSDFLW5BREkjHmYX/ESqkaesrMfmRmnzCzo8zsh8A8MxvR0akcM6szsyXAG8Bv3X1+wSr9iWY8a7EqXlY2SRu5sjgZS5ImsSTjy+K+CNqok6TBKGQcCbZ98cXbk0CLzZuj5V3ZbmYlGWMe9keslCOCYfHjZQXLP070L/w2ewrcvRkYZma9gV+a2WB3f6bVKsWK13e6em1mE4GJAAMHDiwh5O2SNnJlcTKWJE1iScaXxX0RtFEnSYNRyDgSbPu114qs19bylDU5BZFkjHnYH7EOjwjiqSnb+impsczd1wPzgOMLXloFtC5gbwDWFHn/NHdvcvemvn37lvInt0nayJXFyViSNIklGV8W90XQRp0kDUYh40iw7bb+3VR0ecqanIJIMsY87I9YKVVD/2RmPzGzB+LnHzOzc0p4X9/4SAAz6w4cBzxfsNqvgLPi6qFDgQ3lvD4AyRu5sjgZS5ImsSTjy+K+CNqok6TBKGQcCbZ95ZXQo8eOy3r0iJZ3ZbuZlWSMedgfsVKuEfwMeBDoFz//M1GTWUf2Ah41s6XAn4iuEdxvZpPMbFK8zhzgL8CLwK3Av5QeemnGDu/PVeOG0L93dwzo37s7V40b0ubFzqTrp8H3xw5h/KEDtx0B1Jkx/tCBRauGkowvi/uCoafByVOh1wDAoseTp5bn4t6YH0DTOduPAKwuel6saihkHAm2feaZMG0a7L03mEWP06a1UTUUMua0SDLGPOyPWIcNZWb2J3c/yMwWu/vweNkSdx9WiQALqaFMRCS5zk5M0+ItM9uT+CJuyymcMsaXOpmsnZfKyGINesiYQ40xJT0YeVFKIriQ6Fz+vmb2R6AvcGrQqKook7XzUhlZrEEPGXOoMaakByNPSqkaWgQcRVQu+hXgAHdfGjqwaslk7bxURhZr0EPGHGqMKenByJNSqoY+RzQnwbPAWOCeWr4nUCZr56UysliDHjLmUGNMSQ9GnpRSNfRv7r4xvsfQaOB2onsC1aRM1s5LZWSxBj1kzKHGmJIejDwpJRG0nCc5CbjZ3e8D3hcupOrKZO28VEYWa9BDxhxqjCnpwciTUhLBajP7EXAaMMfMdivxfZmUydp5qYws1qCHjDnUGFPSg5EnpfQR
9CC6NcQyd19hZnsBQ9z9oUoEWEh9BCIiyXWpj8DdNwOzWj1/HSjrbSBEalKSuQvSIosxp6UvIC1xdEIpfQQiklSSuQvSIosxp6UvIC1xdFLNnusXqaokcxekRRZjTktfQFri6CQlApEQksxdkBZZjDktfQFpiaOTlAhEQkgyd0FaZDHmtPQFpCWOTlIiEAkhydwFaZHFmNPSF5CWODpJiUAkhCRzF6RFFmNOS19AWuLopA77CNJGfQQiIsl1dT4CkTCyWHcdMuZQNfxZ3M9SUUoEUh1ZrLsOGXOoGv4s7mepOF0jkOrIYt11yJhD1fBncT9LxSkRSHVkse46ZMyhavizuJ+l4pQIpDqyWHcdMuZQNfxZ3M9ScUoEUh1ZrLsOGXOoGv4s7mepOCUCqY4s1l2HjDlUDX8W97NUnPoIRERyoL0+Ah0RiCydCdcPhu/2jh6XzqzOdkPFIdIB9RFIvoWqs0+6XdX7SxXpiEDyLVSdfdLtqt5fqkiJQPItVJ190u2q3l+qSIlA8i1UnX3S7areX6pIiUDyLVSdfdLtqt5fqkiJQPItVJ190u2q3l+qSH0EIiI5UJU+AjMbYGaPmtlyM3vWzC4oss4nzGyDmS2Jf3QcLCJSYSH7CN4FvuHui8ysJ7DQzH7r7s8VrPe4u48JGIdUUhYnQUkScxbHlxbad6kVLBG4++vA6/HvG81sOdAfKEwEUiuy2BSVJOYsji8ttO9SrSIXi82sERgOzC/y8mFm9rSZPWBmB1QiHgkki01RSWLO4vjSQvsu1YLfYsLM9gD+C/iau79Z8PIiYG9332RmJwKzgf2LbGMiMBFg4MCBYQOWzstiU1SSmLM4vrTQvku1oEcEZlZPlASmu/uswtfd/U133xT/PgeoN7M+Rdab5u5N7t7Ut2/fkCFLV2SxKSpJzFkcX1po36VayKohA34CLHf3ojdVN7MPx+thZgfH8awLFZMElsWmqCQxZ3F8aaF9l2ohTw2NAr4ILDOzJfGy7wADAdz9FuBU4FwzexfYApzuWWtskO1aLvplqTIkScxZHF9aaN+lmhrKRERyoL2GMs1HkEeq597R/RfCwp+BN0dTRI6c0PUpIkUyRIkgb1TPvaP7L4QFP9n+3Ju3P1cykJzQTefyRvXcO1r4s2TLRWqQEkHeqJ57R96cbLlIDVIiyBvVc+/I6pItF6lBSgR5o3ruHY2ckGy5SA1SIsgbTYCyozE/gKZzth8BWF30XBeKJUfURyAikgPqIwho9uLVXPvgC6xZv4V+vbszefQgxg7vX+2wyicPPQd5GGMaaD+nlhJBF8xevJops5axZWtUYbJ6/RamzFoGUBvJIA89B3kYYxpoP6earhF0wbUPvrAtCbTYsrWZax98oUoRlVkeeg7yMMY00H5ONSWCLlizfkui5ZmTh56DPIwxDbSfU02JoAv69e6eaHnm5KHnIA9jTAPt51RTIuiCyaMH0b1+x8aj7vV1TB49qEoRlVkeeg7yMMY00H5ONV0s7oKWC8I1WzWUh3vI52GMaaD9nGrqIxARyYH2+gh0akik1i2dCdcPhu/2jh6XzszGtqVidGpIpJaFrN9Xb0DN0BGBSC0LWb+v3oCaoUQgUstC1u+rN6BmKBGI1LKQ9fvqDagZSgQitSxk/b56A2qGEoFILQs5/4TmtqgZ6iMQEckB9RGIiEiblAhERHJOiUBEJOeUCEREck6JQEQk55QIRERyTolARCTnlAhERHIuWCIwswFm9qiZLTezZ83sgiLrmJlNNbMXzWypmY0IFY+IiBQX8ojgXeAb7v5R4FDgPDP7WME6JwD7xz8TgZsDxiOdoYlHRGpesETg7q+7+6L4943AcqBwMt9TgJ975Emgt5ntFSomSahl4pENKwHfPvGIkoFITanINQIzawSGA/MLXuoPrGz1fBU7JwupFk08IpILwROBme0B/BfwNXd/s/DlIm/Z6S54ZjbRzBaY2YK1a9eGCFOK0cQjIrkQNBGY
WT1REpju7rOKrLIKGNDqeQOwpnAld5/m7k3u3tS3b98wwcrONPGISC6ErBoy4CfAcnf/QRur/Qo4K64eOhTY4O6vh4pJEtLEIyK5sGvAbY8CvggsM7Ml8bLvAAMB3P0WYA5wIvAisBn4csB4JKmWCUYeviI6HdSrIUoCmnhEpKZoYhoRkRzQxDQiItImJQIRkZxTIhARyTklAhGRnFMiEBHJucxVDZnZWuDVasdRRB/gb9UOIqBaHx/U/hg1vuzryhj3dveiHbmZSwRpZWYL2irNqgW1Pj6o/TFqfNkXaow6NSQiknNKBCIiOadEUD7Tqh1AYLU+Pqj9MWp82RdkjLpGICKSczoiEBHJOSWChMyszswWm9n9RV77hJltMLMl8U/m7tdsZq+Y2bI4/p3u7hffMnyqmb1oZkvNbEQ14uyKEsaY6c/RzHqb2b1m9ryZLTezwwpez/RnWML4sv75DWoV+xIze9PMvlawTlk/w5C3oa5VFxDNv/z+Nl5/3N3HVDCeEI5297ZqlU8A9o9/DgFujh+zpr0xQrY/x/8E5rr7qWb2PqBHwetZ/ww7Gh9k+PNz9xeAYRD9wxNYDfyyYLWyfoY6IkjAzBqAk4AfVzuWKjoF+LlHngR6m9le1Q5KImb2fuBIokmhcPf/cff1Batl9jMscXy15FjgJXcvbKIt62eoRJDMDcA3gffaWecwM3vazB4wswMqE1ZZOfCQmS00s4lFXu8PrGz1fFW8LEs6GiNk93PcB1gL/DQ+hfljM9u9YJ0sf4aljA+y+/kVOh2YUWR5WT9DJYISmdkY4A13X9jOaouI2rgPBG4EZlcitjIb5e4jiA49zzOzIwtetyLvyVrpWUdjzPLnuCswArjZ3YcDbwHfLlgny59hKePL8ue3TXza69PAL4q9XGRZpz9DJYLSjQI+bWavAHcDx5jZna1XcPc33X1T/PscoN7M+lQ80i5w9zXx4xtE5yUPLlhlFTCg1fMGYE1loiuPjsaY8c9xFbDK3efHz+8l+uIsXCern2GH48v459faCcAid/9rkdfK+hkqEZTI3ae4e4O7NxIdrj3i7uNbr2NmHzYzi38/mGj/rqt4sJ1kZrubWc+W34FPAc8UrPYr4Ky4auFQYIO7v17hUDutlDFm+XN09/8HrDSzQfGiY4HnClbL7GdYyviy/PkVOIPip4WgzJ+hqoa6yMwmAbj7LcCpwLlm9i6wBTjds9Wx90/AL+P/h3YF7nL3uQVjnAOcCLwIbAa+XKVYO6uUMWb9c/wqMD0+tfAX4Ms19hl2NL6sf36YWQ/gk8BXWi0L9hmqs1hEJOd0akhEJOeUCEREck6JQEQk55QIRERyTolARCTnlAhEEorvbtnW3Wd3Wl6GvzfWzD7W6vk8M6vpuXmlspQIRNJvLPCxjlYS6SwlAqk5cffwb+Kbjj1jZp+Pl480s9/HN5t7sOVujfG/sG8wsyfi9Q+Olx8cL1scPw5q7+8WieE2M/tT/P5T4uUTzGyWmc01sxVm9h+t3nOOmf05judWM7vJzD5OdL+Zay26N/2+8eqfM7On4vWPKNOuk5xSZ7HUouOBNe5+EoCZ9TKzeqIbkJ3i7mvj5HAlcHb8nt3d/ePxDehuAwYDzwNHuvu7ZnYc8O/AZ0uM4WKi25CcbWa9gafM7Hfxa8OA4cA7wAtmdiPQDPwb0X1zNgKPAE+7+xNm9ivgfne/Nx4PwK7ufrCZnQhcBhyXfDeJRJQIpBYtA64zs2uIvkAfN7PBRF/uv42/SOuA1vdmmQHg7o+Z2fvjL++ewO1mtj/RnR3rE8TwKaKbFF4UP+8GDIx/f9jdNwCY2XPA3kAf4Pfu/vd4+S+A/9XO9mfFjwuBxgRxiexEiUBqjrv/2cxGEt2L5Soze4joLqPPuvthbb2tyPPvAY+6+2fMrBGYlyAMAz4bzza1faHZIURHAi2aif4/LHZb4fa0bKPl/SKdpmsEUnPMrB+w2d3vBK4jOt3yAtDX4vlt
zazedpywpOU6wuFEd3LcAPQimiYQYELCMB4EvtrqLpjDO1j/KeAoM/uAme3KjqegNhIdnYgEoX9JSC0aQnRx9T1gK3Cuu/+PmZ0KTDWzXkT/7d8APBu/5x9m9gTRXNQt1w3+g+jU0IVE5+yT+F68/aVxMngFaHMOXXdfbWb/Dswnuq/8c8CG+OW7gVvN7HyiO2uKlJXuPiq5Z2bzgIvcfUGV49jD3TfFRwS/BG5z98JJy0XKTqeGRNLju2a2hGiinJfJ6BSLkj06IhARyTkdEYiI5JwSgYhIzikRiIjknBKBiEjOKRGIiOScEoGISM79f6cK54eG1HRaAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')\n",
    "plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')\n",
    "plt.plot(test_point[0], test_point[1], 'bo', label='test_point')\n",
    "plt.xlabel('sepal length')\n",
    "plt.ylabel('sepal width')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### scikit-learn example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier()"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_sk = KNeighborsClassifier()\n",
    "clf_sk.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf_sk.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
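    "### sklearn.neighbors.KNeighborsClassifier\n",
    "\n",
    "- n_neighbors: number of neighbors to use (default 5)\n",
    "- p: power parameter of the Minkowski ($L_p$) distance; `p=1` is the Manhattan distance and `p=2` (the default) the Euclidean distance\n",
    "- algorithm: algorithm used to find the nearest neighbors, one of {'auto', 'ball_tree', 'kd_tree', 'brute'}\n",
    "- weights: how the neighbors are weighted in the vote, e.g. 'uniform' (equal weights, the default) or 'distance' (inverse of the distance)"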
   ]
  },
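  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `weights` parameter decides how the $k$ neighbors vote. In scikit-learn, `'uniform'` gives every neighbor the same weight (plain majority voting), while `'distance'` weights each neighbor by the inverse of its distance, so closer neighbors count more. The cell below is a minimal pure-Python sketch of the two voting rules, not scikit-learn's implementation; the helper `knn_vote` and the toy data are hypothetical, and in this sketch an exact match (distance 0) simply decides the vote.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "def knn_vote(X, y, query, k=3, p=2, weights='uniform'):\n",
    "    # Hypothetical helper: classify `query` from its k nearest neighbors in X\n",
    "    def lp(a, b):\n",
    "        # L_p (Minkowski) distance between two equal-length vectors\n",
    "        return sum(abs(ai - bi) ** p for ai, bi in zip(a, b)) ** (1 / p)\n",
    "\n",
    "    # Sort (distance, label) pairs and keep the k closest training points\n",
    "    neighbors = sorted(zip((lp(x, query) for x in X), y))[:k]\n",
    "\n",
    "    votes = defaultdict(float)\n",
    "    for d, label in neighbors:\n",
    "        if weights == 'uniform':\n",
    "            votes[label] += 1.0        # majority vote: each neighbor counts once\n",
    "        else:\n",
    "            if d == 0:                 # exact match: return its label directly\n",
    "                return label\n",
    "            votes[label] += 1.0 / d    # closer neighbors get larger weights\n",
    "    return max(votes, key=votes.get)\n",
    "\n",
    "X = [[0, 0], [0, 1], [1, 0], [5, 5], [6, 5]]\n",
    "y = [0, 0, 0, 1, 1]\n",
    "print(knn_vote(X, y, [4, 4], k=5))                      # uniform voting\n",
    "print(knn_vote(X, y, [4, 4], k=5, weights='distance'))  # distance-weighted voting\n",
    "```\n",
    "\n",
    "With $k=5$ all five points vote: uniform voting returns class 0 (three votes against two), whereas distance weighting returns class 1 because the two class-1 points are much closer to $(4, 4)$."
   ]
  },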
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# kd tree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
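    "The simplest way to implement the $k$-nearest-neighbor method is a linear scan, which computes the distance between the input instance and every training instance. When the training set is large, this is very time-consuming. To make $k$-nearest-neighbor search more efficient, the training data can be stored in a special structure that reduces the number of distance computations; the kd tree is one such method.\n",
    "\n",
    "A **kd** tree is a tree data structure that stores the instance points of a $k$-dimensional space so that they can be retrieved quickly.\n",
    "\n",
    "A **kd** tree is a binary tree that represents a partition of the $k$-dimensional space. Constructing a **kd** tree amounts to repeatedly splitting the $k$-dimensional space with hyperplanes perpendicular to the coordinate axes, which produces a series of $k$-dimensional hyperrectangular regions. Each node of the kd tree corresponds to one such $k$-dimensional hyperrectangular region.\n",
    "\n",
    "A **kd** tree is constructed as follows:\n",
    "\n",
    "Construct the root node so that it corresponds to the hyperrectangular region of the $k$-dimensional space that contains all instance points. Then split the $k$-dimensional space recursively to generate child nodes: on the hyperrectangular region of a node, choose a coordinate axis and a splitting point on that axis. Together they determine a hyperplane that passes through the splitting point and is perpendicular to the chosen axis; this hyperplane splits the current region into a left and a right subregion (the child nodes), and the instances are divided between the two subregions. The process stops when a subregion contains no instances (the nodes at termination are leaf nodes). Throughout the process, instances are stored at the corresponding nodes.\n",
    "\n",
    "Usually the coordinate axes are chosen in turn, and the median of the training instances' coordinates on the chosen axis is used as the splitting point; the resulting **kd** tree is balanced. Note that a balanced **kd** tree is not necessarily the most efficient one to search.\n"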
   ]
  },
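  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the linear scan mentioned above fits in a few lines: it computes the Euclidean distance from the target to every training point and keeps the smallest, so each query costs $N$ distance computations. The helper name `linear_scan` is hypothetical, and `math.dist` requires Python 3.8 or later.\n",
    "\n",
    "```python\n",
    "from math import dist  # Euclidean distance between two points (Python 3.8+)\n",
    "\n",
    "def linear_scan(data, target):\n",
    "    # Return (nearest_point, nearest_distance) by checking every training point\n",
    "    best_point, best_dist = None, float('inf')\n",
    "    for x in data:\n",
    "        d = dist(x, target)\n",
    "        if d < best_dist:\n",
    "            best_point, best_dist = x, d\n",
    "    return best_point, best_dist\n",
    "\n",
    "data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]\n",
    "print(linear_scan(data, [3, 4.5]))\n",
    "```\n",
    "\n",
    "On the data of Example 3.2 this returns the point `[2, 3]` at distance $\\sqrt{3.25} \\approx 1.803$; a kd tree finds the same point while typically computing far fewer distances on large data sets."
   ]
  },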
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
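    "### Algorithm: constructing a balanced kd tree\n",
    "Input: a $k$-dimensional data set $T=\\{x_1, x_2, \\dots, x_N\\}$,\n",
    "\n",
    "where $x_{i}=\\left(x_{i}^{(1)}, x_{i}^{(2)}, \\cdots, x_{i}^{(k)}\\right)^{\\mathrm{T}}$, $i=1,2,\\dots,N$;\n",
    "\n",
    "Output: a **kd** tree.\n",
    "\n",
    "(1) Start: construct the root node, which corresponds to the hyperrectangular region of the $k$-dimensional space that contains $T$.\n",
    "\n",
    "Choose $x^{(1)}$ as the coordinate axis and take the median of the $x^{(1)}$ coordinates of all instances in $T$ as the splitting point, which splits the hyperrectangular region of the root into two subregions. The split is made by the hyperplane that passes through the splitting point and is perpendicular to the axis $x^{(1)}$.\n",
    "\n",
    "The root generates left and right child nodes of depth 1: the left child corresponds to the subregion whose $x^{(1)}$ coordinate is less than the splitting point, and the right child to the subregion whose $x^{(1)}$ coordinate is greater than the splitting point.\n",
    "\n",
    "Instance points that lie on the splitting hyperplane are stored at the root node.\n",
    "\n",
    "(2) Repeat: for a node of depth $j$, choose $x^{(l)}$ as the splitting axis, where $l = j \\bmod k + 1$, and take the median of the $x^{(l)}$ coordinates of all instances in the node's region as the splitting point, which splits the node's hyperrectangular region into two subregions. The split is made by the hyperplane that passes through the splitting point and is perpendicular to the axis $x^{(l)}$.\n",
    "\n",
    "The node generates left and right child nodes of depth $j+1$: the left child corresponds to the subregion whose $x^{(l)}$ coordinate is less than the splitting point, and the right child to the subregion whose $x^{(l)}$ coordinate is greater than the splitting point.\n",
    "\n",
    "Instance points that lie on the splitting hyperplane are stored at this node.\n",
    "\n",
    "(3) Stop when neither subregion contains any instance. The result is the region partition of the **kd** tree."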
   ]
  },
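  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step (2) uses the 1-based axis index $l = j \\bmod k + 1$; with 0-based Python indexing this is simply `depth % k`. The sketch below (the function `partition_trace` is hypothetical) follows the algorithm on the data set of Example 3.2 and records, in pre-order, the depth, the 0-based splitting axis and the median point of every node. Like the `KdTree` code in this notebook, it takes the element at index `len(points) // 2` of the sorted points as the median.\n",
    "\n",
    "```python\n",
    "def partition_trace(points, k, depth=0, trace=None):\n",
    "    # Record (depth, axis, median) for every node in pre-order\n",
    "    if trace is None:\n",
    "        trace = []\n",
    "    if not points:                       # empty region: no node is created\n",
    "        return trace\n",
    "    axis = depth % k                     # 0-based form of l = j mod k + 1\n",
    "    points = sorted(points, key=lambda x: x[axis])\n",
    "    mid = len(points) // 2               # index of the median point\n",
    "    trace.append((depth, axis, tuple(points[mid])))\n",
    "    partition_trace(points[:mid], k, depth + 1, trace)      # left subregion\n",
    "    partition_trace(points[mid + 1:], k, depth + 1, trace)  # right subregion\n",
    "    return trace\n",
    "\n",
    "T = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]\n",
    "for depth, axis, median in partition_trace(T, k=2):\n",
    "    print(depth, axis, median)\n",
    "```\n",
    "\n",
    "The root splits on $x^{(1)}$ (axis 0) at $(7, 2)$, its children split on $x^{(2)}$ (axis 1) at $(5, 4)$ and $(9, 6)$, and the depth-2 nodes split on $x^{(1)}$ again, matching the pre-order output of Example 3.2."
   ]
  },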
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Building the kd tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main data held by each node of the kd tree\n",
    "class KdNode(object):\n",
    "    def __init__(self, dom_elt, split, left, right):\n",
    "        self.dom_elt = dom_elt  # k-dimensional vector (a sample point in k-dimensional space)\n",
    "        self.split = split      # integer: index of the dimension used for splitting\n",
    "        self.left = left        # kd tree built from the left subspace of this node's splitting hyperplane\n",
    "        self.right = right      # kd tree built from the right subspace of this node's splitting hyperplane\n",
    "\n",
    "# Build the kd tree\n",
    "class KdTree(object):\n",
    "    def __init__(self, data):\n",
    "        k = len(data[0])  # data dimensionality\n",
    "\n",
    "        # Create a node\n",
    "        def CreateNode(split, data_set):  # split data_set on dimension `split` and create a KdNode\n",
    "            if not data_set:  # empty data set\n",
    "                return None\n",
    "            # Sort on the splitting dimension; sorted() returns a new list, so the caller's data is not reordered\n",
    "            # (operator.itemgetter(split) would also work as the sort key)\n",
    "            data_set = sorted(data_set, key=lambda x: x[split])\n",
    "            split_pos = len(data_set) // 2          # // is integer division in Python\n",
    "            median = data_set[split_pos]            # median splitting point\n",
    "            split_next = (split + 1) % k            # dimension used for the next split\n",
    "\n",
    "            # Build the kd tree recursively\n",
    "            return KdNode(\n",
    "                median,\n",
    "                split,\n",
    "                CreateNode(split_next, data_set[:split_pos]),      # build the left subtree\n",
    "                CreateNode(split_next, data_set[split_pos + 1:]))  # build the right subtree\n",
    "\n",
    "        self.root = CreateNode(0, data)  # start splitting on dimension 0 and keep the root node\n",
    "\n",
    "\n",
    "# Pre-order traversal of the kd tree\n",
    "def preOrder(root):\n",
    "    print(root.dom_elt)\n",
    "    if root.left:  # left child is not empty\n",
    "        preOrder(root.left)\n",
    "    if root.right:\n",
    "        preOrder(root.right)"
   ]
  },
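  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `CreateNode` puts `len(data_set) // 2` points in the left subtree and the remaining `len(data_set) - len(data_set) // 2 - 1` points in the right one, the two subtrees differ in size by at most one, so a tree built from $N$ points has height $\\lfloor \\log_2 N \\rfloor + 1$, which is `N.bit_length()` in Python. The self-contained check below uses a hypothetical helper `build_height` that repeats the same median split but returns only the height.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def build_height(points, split=0):\n",
    "    # Height of the kd tree produced by median splitting (same split rule as CreateNode)\n",
    "    if not points:\n",
    "        return 0\n",
    "    k = len(points[0])\n",
    "    points = sorted(points, key=lambda x: x[split])\n",
    "    mid = len(points) // 2\n",
    "    nxt = (split + 1) % k\n",
    "    return 1 + max(build_height(points[:mid], nxt),\n",
    "                   build_height(points[mid + 1:], nxt))\n",
    "\n",
    "random.seed(0)\n",
    "for n in (1, 6, 7, 8, 100, 1000):\n",
    "    pts = [[random.random() for _ in range(3)] for _ in range(n)]\n",
    "    print(n, build_height(pts), n.bit_length())\n",
    "```\n",
    "\n",
    "The two printed heights agree for every $n$; for the six points of Example 3.2 the height is 3."
   ]
  },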
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Searching the kd tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search the constructed kd tree for the sample point nearest to a target point\n",
    "from math import sqrt\n",
    "from collections import namedtuple\n",
    "\n",
    "# A namedtuple holding the nearest point, the nearest distance and the number of visited nodes\n",
    "result = namedtuple(\"Result_tuple\",\"nearest_point  nearest_dist  nodes_visited\")\n",
    "\n",
    "# Search the kd tree for the point nearest to `point`\n",
    "def find_nearest(tree, point):\n",
    "    k = len(point)  # data dimensionality\n",
    "\n",
    "    def travel(kd_node, target, max_dist):\n",
    "        if kd_node is None:\n",
    "            return result([0] * k, float(\"inf\"), 0)  # float(\"inf\") and float(\"-inf\") denote positive and negative infinity in Python\n",
    "\n",
    "        nodes_visited = 1\n",
    "\n",
    "        s = kd_node.split        # splitting dimension\n",
    "        pivot = kd_node.dom_elt  # splitting point (the pivot)\n",
    "\n",
    "        #-----------------------------------------------------------------------------------\n",
    "        # Descend to the leaf whose region contains the target\n",
    "        if target[s] <= pivot[s]:         # target's s-th coordinate is at most the pivot's (target lies on the left side)\n",
    "            nearer_node = kd_node.left    # visit the left subtree next\n",
    "            further_node = kd_node.right  # remember the right subtree\n",
    "        else:                             # target lies on the right side\n",
    "            nearer_node = kd_node.right   # visit the right subtree next\n",
    "            further_node = kd_node.left\n",
    "\n",
    "        temp1 = travel(nearer_node, target, max_dist)  # recurse into the region that contains the target\n",
    "\n",
    "        #-------------------------------------------------------------------------------------\n",
    "        # Take the best point found in that subtree as the current nearest point\n",
    "        nearest = temp1.nearest_point\n",
    "        dist = temp1.nearest_dist  # update the nearest distance\n",
    "\n",
    "        nodes_visited += temp1.nodes_visited\n",
    "\n",
    "        if dist < max_dist:\n",
    "            max_dist = dist       # the nearest point lies inside the hypersphere centered at the target with radius max_dist\n",
    "\n",
    "        temp_dist = abs(pivot[s] - target[s])            # distance from the target to the splitting hyperplane along dimension s\n",
    "        if max_dist < temp_dist:                         # does the hypersphere intersect the hyperplane?\n",
    "            return result(nearest, dist, nodes_visited)  # no: neither the pivot nor the other side can beat the current radius, so return\n",
    "\n",
    "        #----------------------------------------------------------------------\n",
    "        # Euclidean distance between the target and the splitting point\n",
    "        temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))\n",
    "\n",
    "        if temp_dist < dist:   # the pivot is closer\n",
    "            nearest = pivot    # update the nearest point\n",
    "            dist = temp_dist   # update the nearest distance\n",
    "            max_dist = dist    # shrink the hypersphere radius\n",
    "\n",
    "        # Check whether the region of the other child contains a closer point\n",
    "        temp2 = travel(further_node, target, max_dist)\n",
    "\n",
    "        nodes_visited += temp2.nodes_visited\n",
    "        if temp2.nearest_dist < dist:         # the other child contains a closer point\n",
    "            nearest = temp2.nearest_point     # update the nearest point\n",
    "            dist = temp2.nearest_dist         # update the nearest distance\n",
    "\n",
    "        return result(nearest, dist, nodes_visited)\n",
    "\n",
    "    return travel(tree.root, point, float(\"inf\"))  # start the recursion at the root"
   ]
  },
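  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`find_nearest` returns only the single nearest neighbor, whereas the $k$-nearest-neighbor method needs the $k$ closest training points. One common extension, sketched here under our own assumptions (the helpers `build` and `knn_search` and the tuple node layout are hypothetical), keeps the $k$ best candidates in a bounded max-heap, implemented with `heapq` on negated distances, and skips the far side of a split only when the heap is full and the splitting hyperplane is farther away than the current $k$-th best distance.\n",
    "\n",
    "```python\n",
    "import heapq\n",
    "from math import dist\n",
    "\n",
    "def build(points, depth=0):\n",
    "    # Node layout: (point, axis, left, right); median split as in KdTree\n",
    "    if not points:\n",
    "        return None\n",
    "    axis = depth % len(points[0])\n",
    "    points = sorted(points, key=lambda x: x[axis])\n",
    "    mid = len(points) // 2\n",
    "    return (points[mid], axis,\n",
    "            build(points[:mid], depth + 1), build(points[mid + 1:], depth + 1))\n",
    "\n",
    "def knn_search(root, target, k):\n",
    "    # Return the k nearest (point, distance) pairs, nearest first; assumes k >= 1\n",
    "    heap = []     # entries (-distance, counter, point); heap[0] holds the current k-th nearest\n",
    "    counter = 0   # tie-breaker so that points themselves are never compared\n",
    "\n",
    "    def visit(node):\n",
    "        nonlocal counter\n",
    "        if node is None:\n",
    "            return\n",
    "        point, axis, left, right = node\n",
    "        d = dist(point, target)\n",
    "        if len(heap) < k:\n",
    "            heapq.heappush(heap, (-d, counter, point))\n",
    "        elif d < -heap[0][0]:                       # closer than the current k-th nearest\n",
    "            heapq.heapreplace(heap, (-d, counter, point))\n",
    "        counter += 1\n",
    "        # Search the side of the splitting hyperplane that contains the target first\n",
    "        if target[axis] <= point[axis]:\n",
    "            near, far = left, right\n",
    "        else:\n",
    "            near, far = right, left\n",
    "        visit(near)\n",
    "        # Visit the far side only if the k-th-best hypersphere crosses the hyperplane\n",
    "        if len(heap) < k or abs(target[axis] - point[axis]) < -heap[0][0]:\n",
    "            visit(far)\n",
    "\n",
    "    visit(root)\n",
    "    return [(p, -neg_d) for neg_d, _, p in sorted(heap, reverse=True)]\n",
    "\n",
    "data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]\n",
    "root = build(data)\n",
    "for p, d in knn_search(root, [3, 4.5], k=3):\n",
    "    print(p, round(d, 4))\n",
    "```\n",
    "\n",
    "With `k=1` this performs essentially the same search as `find_nearest`; for Example 3.2 with target $(3, 4.5)$ and $k=3$ the neighbors are `[2, 3]`, `[5, 4]` and `[4, 7]`, nearest first."
   ]
  },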
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 3.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7, 2]\n",
      "[5, 4]\n",
      "[2, 3]\n",
      "[4, 7]\n",
      "[9, 6]\n",
      "[8, 1]\n"
     ]
    }
   ],
   "source": [
    "# Build the kd tree\n",
    "data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]\n",
    "kd = KdTree(data)\n",
    "preOrder(kd.root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import process_time\n",
    "from random import random\n",
    "\n",
    "# Generate a random k-dimensional vector whose components lie in [0, 1)\n",
    "def random_point(k):\n",
    "    return [random() for _ in range(k)]\n",
    "\n",
    "# Generate n random k-dimensional vectors\n",
    "def random_points(k, n):\n",
    "    return [random_point(k) for _ in range(n)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)\n"
     ]
    }
   ],
   "source": [
    "# Search the kd tree\n",
    "ret = find_nearest(kd, [3, 4.5])\n",
    "print(ret)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time:  7.1875 s\n",
      "Result_tuple(nearest_point=[0.09674566867259649, 0.5054271513253492, 0.8014541012628393], nearest_dist=0.006493000414242284, nodes_visited=56)\n"
     ]
    }
   ],
   "source": [
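    "N = 400000\n",
    "t0 = process_time()                          # CPU time covers both building and searching the tree\n",
    "kd2 = KdTree(random_points(3, N))            # build a kd tree from 400,000 random 3-D sample points\n",
    "ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])    # find the sample nearest to the target among the 400,000 points\n",
    "t1 = process_time()\n",
    "print(\"time: \", t1 - t0, \"s\")\n",
    "print(ret2)"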
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
